import numpy as np
import torch

import pandas as pd
from torch import optim, nn
from torch.distributions import RelaxedOneHotCategorical, OneHotCategorical
from torch.utils.data import DataLoader

from Image_Mediator_Training.imageMediator_graph import set_imageMediator
from torch.nn import functional as F

from ModularUtils.ControllerConstants import get_multiple_labels_fill, map_dictfill_to_discrete
from ModularUtils.Discriminators import DigitImageDiscriminator
from ModularUtils.Experiment_Class import Experiment
from ModularUtils.FrontBackDoorCalculation import estiamte_ate_backdoor_direct
from ModularUtils.FunctionsConstant import get_Imagedataset, get_dataset, save_datasets
from ModularUtils.FunctionsDistribution import compare_conditionals_within
from ModularUtils.FunctionsTraining import calc_gradient_penalty, labels_image_gradient_penalty
from ModularUtils.Generators import ConditionalClassifier, classifierCritic

# Dimensions shared by the VAE and the label-filling calls in this file.
# U_dim and pC_dim are handed to get_multiple_labels_fill as the per-label
# widths of the U and C parts (presumably one-hot widths — TODO confirm).
U_dim = 4
pC_dim = 3
# Number of categories in the VAE's discrete latent code (encoder output width).
midC_dim = 2
# Width of the filled vector fed to the VAE encoder: U width plus C width.
input_dim = U_dim + pC_dim

def init_weights(m):
    """Initialise one module in place; intended for ``Module.apply``.

    Only ``torch.nn.Linear`` layers are modified: the weight matrix gets
    Xavier/Glorot uniform initialisation and the bias (when the layer has
    one) is filled with the constant 0.01. Every other module type is left
    untouched.

    NOTE: per the original author, generators and discriminators are
    initialised inside their own classes, so this helper is not meant for them.

    Args:
        m: Any ``torch.nn.Module``; typically supplied by ``model.apply``.
    """
    if not isinstance(m, torch.nn.Linear):
        return
    # ``xavier_uniform_`` is the in-place, non-deprecated spelling of
    # ``xavier_uniform``.
    torch.nn.init.xavier_uniform_(m.weight)
    # Layers built with ``bias=False`` expose ``bias is None``; skip those
    # instead of raising an AttributeError.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0.01)


class VAE(nn.Module):
    """Autoencoder with a discrete (categorical) latent code.

    Encoder: Linear(input_dim, 100) -> ReLU -> Linear(100, midC_dim) -> softmax,
    yielding category probabilities ``p`` over ``midC_dim`` latent classes.
    Decoder: Linear(midC_dim, 100) -> ReLU -> Linear(100, U_dim + pC_dim),
    returned without a final activation.

    ``forward`` reads the module-level ``Exp`` object, so that global must be
    defined before the model is called.
    """

    def __init__(self):
        super().__init__()

        # Encoder layers.
        self.fc1 = nn.Linear(input_dim, 100)
        self.fc2 = nn.Linear(100, midC_dim)
        # Decoder layers.
        self.fc3 = nn.Linear(midC_dim, 100)
        self.fc4 = nn.Linear(100, U_dim + pC_dim)

    def encode(self, x):
        """Return latent category probabilities of shape (len(x), midC_dim)."""
        hidden = F.relu(self.fc1(x))
        scores = self.fc2(hidden).view(len(x), midC_dim)
        return F.softmax(scores, -1)

    def reparameterize(self, p):
        """Draw a latent sample from the probabilities ``p``.

        In training mode the sample comes from a relaxed (Gumbel-Softmax)
        one-hot categorical, which is continuous and differentiable; the lower
        the temperature, the closer the samples get to one-hot vectors.
        In evaluation mode an exact one-hot categorical sample is returned.
        """
        if not self.training:
            return OneHotCategorical(p).sample()

        temperature = 1
        relaxed = RelaxedOneHotCategorical(temperature, p)
        return relaxed.rsample()

    def decode(self, z):
        """Map a latent sample back to an unactivated reconstruction."""
        latent = z.view(len(z), midC_dim)
        hidden = F.relu(self.fc3(latent))
        return self.fc4(hidden)

    def forward(self, x):
        """Fill the labels, encode, sample and decode.

        Returns:
            A tuple ``(reconstruction, p)`` where ``p`` holds the latent
            category probabilities produced by the encoder.
        """
        filled = get_multiple_labels_fill(Exp, x, [U_dim, pC_dim], isImage_labels=False)

        p = self.encode(filled.view(-1, input_dim))
        sample = self.reparameterize(p)
        return self.decode(sample), p


# Reconstruction loss (mean squared error per label block); the KL term is
# currently disabled.
def loss_function(recon_x, x, p):
    """Return the reconstruction loss between ``recon_x`` and the filled ``x``.

    ``x`` is first expanded with ``get_multiple_labels_fill`` (using the
    module-level ``Exp``), then compared with the decoder output block by
    block: the first ``U_dim`` feature columns and the following ``pC_dim``
    feature columns each contribute one mean-reduced MSE term.

    Args:
        recon_x: Decoder output; assumed to be (batch, U_dim + pC_dim), which
            is what ``VAE.decode`` produces.
        x: Raw label batch as passed to the model.
        p: Latent category probabilities. Currently unused because the KL
            term is switched off.

    Returns:
        Scalar tensor: sum of the two block-wise MSE terms (plus KLD == 0).
    """
    x = get_multiple_labels_fill(Exp, x, [U_dim, pC_dim], isImage_labels=False)
    # Same reshape VAE.forward applies to the filled labels, so both tensors
    # are (batch, input_dim) before slicing.
    x = x.view(-1, input_dim)

    # Slice along the feature axis (dim 1). Indexing with a bare `[0:4]`
    # would select batch rows and silently drop most samples from the loss.
    c_end = U_dim + pC_dim
    mse_u = F.mse_loss(recon_x[:, :U_dim], x[:, :U_dim], reduction='mean')
    mse_c = F.mse_loss(recon_x[:, U_dim:c_end], x[:, U_dim:c_end], reduction='mean')
    mse = mse_u + mse_c

    # With a uniform prior the KL term reduces to the negative entropy of p
    # (up to the constant log(midC_dim)), i.e. roughly
    # -torch.sum(p * torch.log(p + 1e-6)). It is disabled for now.
    KLD = 0

    print("mse", mse)
    return mse + KLD

def train_encoder(Exp, U, C, epochs=1000, batch_size=100, lr=1e-3, log_interval=10):
    """Train a ``VAE`` on the column-wise concatenation of ``U`` and ``C``.

    The data is iterated in its stored order (no shuffling) and incomplete
    trailing batches are dropped. Note that ``VAE.forward`` and
    ``loss_function`` read the module-level ``Exp`` rather than the argument
    passed here, so both should refer to the same experiment object.

    Args:
        Exp: Experiment object; only ``Exp.DEVICE`` is read directly here.
        U: Tensor of U labels, concatenated with ``C`` along dim 1.
        C: Tensor of C labels with the same number of rows as ``U``.
        epochs: Number of passes over the data.
        batch_size: Mini-batch size.
        lr: Adam learning rate.
        log_interval: Print the batch loss every ``log_interval`` batches.

    Returns:
        The trained ``VAE`` instance (left in training mode).
    """
    model = VAE().to(Exp.DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    dataset = torch.cat([U, C], 1).to(Exp.DEVICE)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True)
    num_batches = len(train_loader)

    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0.0
        for batch_idx, data in enumerate(train_loader):
            data = data.to(Exp.DEVICE)

            optimizer.zero_grad()
            recon_batch, p = model(data)
            loss = loss_function(recon_batch, data, p)
            loss.backward()
            optimizer.step()

            # loss_function is mean-reduced, so the value is already a
            # per-batch average and is reported without further scaling.
            batch_loss = loss.item()
            train_loss += batch_loss

            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / num_batches,
                    batch_loss))

        # Average of the per-batch mean losses; guard against an empty loader
        # (dataset smaller than one batch with drop_last=True).
        print('====> Epoch: {} Average loss: {:.4f}'.format(
              epoch, train_loss / max(num_batches, 1)))

    return model




def cond_vs_intv(Exp, U0, D, C):
    """Print the conditional of C given D and the backdoor estimate of P(C|do(D)).

    ``U0``, ``D`` and ``C`` are concatenated column-wise; the code then treats
    column 0 as U0, column 1 as D and column 2 as C, so each input is assumed
    to contribute exactly one column (TODO confirm for other shapes).

    Args:
        Exp: Experiment object forwarded to the project helpers.
        U0: Tensor of the adjustment variable U0.
        D: Tensor of the treatment variable D.
        C: Tensor of the outcome variable C.
    """
    cur_data = torch.cat([U0, D, C], 1).detach().cpu().numpy()

    # Columns 1:3 are (D, C). The outcome name is spelled out here instead of
    # being read from a global that only exists when run as a script.
    cond_prob = compare_conditionals_within(Exp, cur_data[:, 1:3], ['C'], ['D'], ['D', 'C'])
    print("Conditional:", cond_prob)

    # ---------Calculates P(C|do(D)) with backdoor -----------
    px = pd.DataFrame(cur_data).rename(columns={0: 'U0', 1: 'D', 2: 'C'})
    bd_dict = estiamte_ate_backdoor_direct(Exp, px, 'D', 'C', ['U0'])
    print("backdoor:")
    for key in bd_dict:
        print(key, bd_dict[key])





if __name__ == '__main__':
    # Script entry point: build the experiment, synthesise (or load) the C
    # labels, fit a ConditionalClassifier to them with an MSE objective, and
    # report conditional / backdoor statistics of the generated labels.

    # ---- configuration ----
    # `Exp` is deliberately a module-level name: VAE.forward and
    # loss_function in this file read it as a global.
    Exp = Experiment("Exp1", set_imageMediator,
                     dist_thresh=0.15,
                     causal_hierarchy=2,
                     noise_states=64,
                     latent_state=4,
                     new_experiment=False,
                     Synthetic_Sample_Size=20000,
                     intv_Sample_Size=20000,
                     Data_intervs=[{}],
                     allowed_noise=0.05,
                     learning_rate=0.01
                     )

    digit_images = get_Imagedataset(Exp, 0, "ImgYdigit1")

    label = 'C'
    momentum = 0.5
    classifier = ConditionalClassifier(output_dim=Exp.label_dim[label]).to(Exp.DEVICE)
    optimizer = optim.SGD(classifier.parameters(), lr=Exp.learning_rate, momentum=momentum)

    D = get_dataset(Exp, 'D', 0)
    U0 = get_dataset(Exp, 'U0', 0)

    # ---- target labels ----
    generate = True
    if generate:
        # First half of the samples depends on both U0 and D, the second half
        # on D only; values are reduced modulo 3.
        ll = 20000
        m = ll // 2
        sumC1 = (U0[0:m] + 5 * D[0:m]) % 3
        sumC2 = (0 + 7 * D[m:ll]) % 3
        newC = torch.cat([sumC1, sumC2], 0)
    else:
        newC = get_dataset(Exp, 'newC', 0)

    # cond_vs_intv takes the experiment as its first argument.
    cond_vs_intv(Exp, U0, D, newC)
    freq = torch.bincount(newC[:, 0].type(torch.LongTensor), minlength=2) / newC.shape[0]
    print("newC marginal", freq)

    real_U0 = U0
    real_D = D
    real_C = newC
    # Broadcast U0 over the spatial dimensions so it can accompany the images.
    U0 = real_U0.unsqueeze(2).unsqueeze(3).repeat(1, 1, Exp.IMAGE_SIZE, Exp.IMAGE_SIZE)
    classifier.apply(init_weights)

    real_C = get_multiple_labels_fill(Exp, real_C, [pC_dim], isImage_labels=False)

    # Loop-invariant save location.
    intvno = 0
    label_save_dir = Exp.file_roots + "intv" + str(intvno)

    # ---- training ----
    max_steps = 1000
    for steps in range(max_steps + 1):
        # Clear gradients from the previous step; without this they
        # accumulate across iterations.
        optimizer.zero_grad()

        gen_C = classifier(Exp, digit_images, U0)

        # Sum of per-column mean squared errors.
        loss = sum(
            F.mse_loss(gen_C[:, col], real_C[:, col], reduction='mean')
            for col in range(pC_dim)
        )
        print(f"steps:{steps}, loss: {loss}")

        loss.backward()
        optimizer.step()

        # Discretise the (detached) continuous output for the diagnostics.
        genC = map_dictfill_to_discrete(Exp, {'C': gen_C.detach()}, ['C'])
        genC = torch.tensor(genC).to(Exp.DEVICE)
        freq = torch.bincount(genC[:, 0].type(torch.LongTensor), minlength=2) / genC.shape[0]
        print("fake freq", freq)
        cond_vs_intv(Exp, real_U0, D, genC)

        save_datasets(True, label_save_dir, "feature", {'newC': genC.cpu().numpy()})

    # ---- final analysis on the last generated labels ----
    # Use the discretised labels (genC): the raw classifier output carries
    # gradients and has one column per class, which does not match the
    # (U0, D, C) column layout expected below.
    cur_data = np.concatenate(
        [real_U0.cpu().numpy(), real_D.cpu().numpy(), genC.cpu().numpy()], 1)

    cond_prob = compare_conditionals_within(Exp, cur_data[:, 1:3], [label], ['D'], ['D', 'C'])

    # Aggregate of the conditional comparison.
    # TODO: turn this into a proper loss function.
    loss = 0
    for key in cond_prob:
        loss += cond_prob[key]

    # ---------Calculates P(C|do(D)) with backdoor -----------
    px = pd.DataFrame(cur_data)
    px = px.rename(columns={0: 'U0', 1: 'D', 2: 'C'})
    bd_dict = estiamte_ate_backdoor_direct(Exp, px, 'D', 'C', ['U0'])
    for key in bd_dict:
        print(key, bd_dict[key])

